import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from sklearn import linear_model

def model(X_train, y_train, X_test, y_test, alphas=(0.01, 0.3, 1)):
    """Fit one Lasso regressor per alpha and plot it on both data splits.

    Builds a 2 x len(alphas) grid of axes: row 0 shows the training set,
    row 1 the testing set, and each column corresponds to one alpha.  Each
    panel title carries the value returned by ``Lasso.score`` (the R^2
    coefficient of determination) for that split.

    Parameters
    ----------
    X_train, X_test : data matrices accepted by ``Lasso.fit`` / ``predict``
        (the script below passes single-column 2-D arrays).
    y_train, y_test : target vectors aligned with the rows of X_train / X_test.
    alphas : iterable of numbers, optional
        Regularization strengths, one figure column per value.  Defaults to
        the previously hard-coded (0.01, 0.3, 1).

    Returns
    -------
    None.  Side effect: displays the figure via ``plt.show()``.

    Raises
    ------
    ValueError
        If ``alphas`` is empty (no column to draw).
    """
    TRAINING, TESTING = 0, 1  # row indices into the axes grid
    alphas = list(alphas)
    if not alphas:
        raise ValueError("alphas must contain at least one value")

    # squeeze=False keeps ``axs`` 2-D even for a single alpha, so the
    # ``axs[row, column]`` indexing done in ``plot`` always works.
    _, axs = plt.subplots(2, len(alphas), figsize=(16, 7), squeeze=False)

    splits = (
        (TRAINING, 'Training set', X_train, y_train),
        (TESTING, 'Testing set', X_test, y_test),
    )
    for column, alpha in enumerate(alphas):
        # The fit depends only on the training data and alpha, so fit once
        # and reuse the model for both rows instead of refitting per row.
        lasso = linear_model.Lasso(alpha=alpha)
        lasso.fit(X_train, y_train)
        for set_type, label, X_set, y_set in splits:
            score = lasso.score(X_set, y_set)
            y_pred = lasso.predict(X_set)
            plot(axs, X_set, y_set, y_pred, alpha, set_type, column,
                 label + ', score: {:3.2f}'.format(score))
    plt.show()

def plot(axs, data2D, target1D, predict1D, alpha,
         set_type, column, label):
    """Draw one panel: observed targets as dots, predictions as a line.

    Parameters
    ----------
    axs : grid of axes indexed as ``axs[row, column]``.
    data2D : x-values shared by both series (the caller passes a
        single-column 2-D array).
    target1D : observed target values, drawn as yellow dots.
    predict1D : predicted values, drawn as a black solid line.
    alpha : regularization strength, shown in the panel title.
    set_type : row index into ``axs`` (``model`` uses 0 for the training
        set and 1 for the testing set).
    column : column index into ``axs``.
    label : text appended to the title after the alpha value.
    """
    panel = axs[set_type, column]
    panel.plot(data2D, target1D, 'y.', markersize=6)
    panel.plot(data2D, predict1D, 'k-')
    title_parts = ('Lasso, alpha:', str(alpha), ', ', label)
    panel.set_title(''.join(title_parts))
    panel.grid()

# Load scikit-learn's diabetes dataset as a feature matrix and target vector.
diabetes = load_diabetes()
X, y = diabetes.data, diabetes.target

# Column indices of the two single features examined one at a time below.
BLOOD_PRESSURE, S3 = 3, 6

for feature in (BLOOD_PRESSURE, S3):
    # Indexing with a one-element list keeps the selection two-dimensional,
    # (n_samples, 1): the fit, predict and score methods need a data matrix,
    # so both splits come out ready to use without a later reshape.
    X_train, X_test, y_train, y_test = train_test_split(
        X[:, [feature]], y, random_state=60)
    model(X_train, y_train, X_test, y_test)
    # Blank line on stdout between the two features' runs.
    print()
